{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tutorial 2. ResNet50 on ImageNet (2012)\n",
    "\n",
    "\n",
    "In this tutorial, we will show\n",
    "\n",
    "- How to train and compress a ResNet50 on ImageNet end to end, reproducing the results reported in the paper."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1. Create OTO instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from only_train_once import OTO\n",
    "import torchvision\n",
    "\n",
    "model = torchvision.models.resnet50(pretrained=True).cuda()\n",
    "dummy_input = torch.zeros(1, 3, 224, 224).cuda()\n",
    "oto = OTO(model=model, dummy_input=dummy_input)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### (Optional) Visualize the dependency graph of the DNN for ZIG partitions\n",
    "\n",
    "- Set `view` to `False` if no browser is accessible.\n",
    "- Open the generated PDF file instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A ResNet_zig.gv.pdf will be generated to display the dependency graph.\n",
    "oto.visualize_zigs(view=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2. Dataset Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import os\n",
    "\n",
    "data_dir = \"/data/imagenet\" # Change to your own imagenet path\n",
    "train_dir = os.path.join(data_dir, 'train')\n",
    "test_dir = os.path.join(data_dir, 'val')\n",
    "batch_size = 128\n",
    "\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(224),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ColorJitter(\n",
    "        brightness=0.4,\n",
    "        contrast=0.4,\n",
    "        saturation=0.4,\n",
    "        hue=0.2),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),])\n",
    "\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),])\n",
    "\n",
    "trainset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train)\n",
    "testset = torchvision.datasets.ImageFolder(root=test_dir, transform=transform_test)\n",
    "\n",
    "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)\n",
    "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=8)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3. Set up the DHSPG optimizer\n",
    "\n",
    "The following main hyperparameters deserve attention.\n",
    "\n",
    "- `variant`: The optimizer used for training the baseline full model.\n",
    "- `lr`: The initial learning rate.\n",
    "- `weight_decay`: Weight decay, as in standard DNN optimization.\n",
    "- `target_group_sparsity`: The target group sparsity. Higher group sparsity typically yields a larger reduction in FLOPs and model size, but may degrade model performance more.\n",
    "- `start_pruning_steps`: The number of training steps after which pruning starts.\n",
    "- `first_momentum`: The first-order momentum.\n",
    "- `epsilon`: A coefficient in [0, 1) that controls the aggressiveness of group sparsity exploration. Higher values mean more aggressive exploration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = oto.dhspg(\n",
    "    variant='sgd', \n",
    "    lr=0.1, \n",
    "    target_group_sparsity=0.4,\n",
    "    weight_decay=0.0, # Some training recipes use 1e-4 instead.\n",
    "    first_momentum=0.9,\n",
    "    start_pruning_steps=15 * len(trainloader), # Start pruning after 15 passes over trainloader.\n",
    "    epsilon=0.95)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4. Train ResNet50 as normal."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add helper functions for label smoothing and mix-up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def one_hot(y, num_classes, smoothing_eps=None):\n",
    "    if smoothing_eps is None:\n",
    "        one_hot_y = F.one_hot(y, num_classes).float()\n",
    "        return one_hot_y\n",
    "    else:\n",
    "        one_hot_y = F.one_hot(y, num_classes).float()\n",
    "        v1 = 1 - smoothing_eps + smoothing_eps / float(num_classes)\n",
    "        v0 = smoothing_eps / float(num_classes)\n",
    "        new_y = one_hot_y * (v1 - v0) + v0\n",
    "        return new_y\n",
    "\n",
    "def cross_entropy_onehot_target(logit, target):\n",
    "    # target must be one-hot format!!\n",
    "    prob_logit = F.log_softmax(logit, dim=1)\n",
    "    loss = -(target * prob_logit).sum(dim=1).mean()\n",
    "    return loss\n",
    "\n",
    "def mixup_func(input, target, alpha=0.2):\n",
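    "    # Mixing coefficient sampled from Beta(alpha, alpha), cast to a Python float.\n",
    "    gamma = float(np.random.beta(alpha, alpha))\n",
    "    # target must be in one-hot format!\n",
    "    perm = torch.randperm(input.size(0), device=input.device)\n",
    "    perm_input = input[perm]\n",
    "    perm_target = target[perm]\n",
    "    # In-place convex combination: gamma * original + (1 - gamma) * permuted.\n",
    "    # The keyword form add_(tensor, alpha=...) replaces the deprecated add_(scalar, tensor) signature.\n",
    "    return input.mul_(gamma).add_(perm_input, alpha=1 - gamma), target.mul_(gamma).add_(perm_target, alpha=1 - gamma)"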
   ]
  },
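  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, the helpers above can be summarized in a few formulas. This is a sketch, not a specification: the symbols $K$ (number of classes), $\\varepsilon$ (`smoothing_eps`), $\\gamma$ (the Beta sample), $\\pi$ (the random permutation), $B$ (batch size) and $z_i$ (logits of sample $i$) are notation introduced here for illustration.\n",
    "\n",
    "Label smoothing, as computed in `one_hot` for a one-hot target $y$:\n",
    "\n",
    "$$\\tilde{y}_k = (1 - \\varepsilon)\\, y_k + \\frac{\\varepsilon}{K}$$\n",
    "\n",
    "Mix-up, as computed in `mixup_func`:\n",
    "\n",
    "$$\\gamma \\sim \\mathrm{Beta}(\\alpha, \\alpha), \\qquad \\hat{x}_i = \\gamma\\, x_i + (1 - \\gamma)\\, x_{\\pi(i)}, \\qquad \\hat{y}_i = \\gamma\\, \\tilde{y}_i + (1 - \\gamma)\\, \\tilde{y}_{\\pi(i)}$$\n",
    "\n",
    "Cross entropy against the resulting soft targets, as computed in `cross_entropy_onehot_target`:\n",
    "\n",
    "$$\\ell = -\\frac{1}{B} \\sum_{i=1}^{B} \\sum_{k=1}^{K} \\hat{y}_{ik}\\, \\log \\operatorname{softmax}(z_i)_k$$\n",
    "\n",
    "For example, with $\\varepsilon = 0.1$ and $K = 1000$, the true class receives $1 - 0.1 + 0.1/1000 = 0.9001$ and every other class receives $0.0001$, so the smoothed target still sums to $1$."
   ]
  },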
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Start official training and compression.\n",
    "\n",
    "With $40\\%$ group sparsity, top-1 accuracy reached $75.2\\%$ in this particular run.\n",
    "\n",
    "Across our experiments with different random seeds, top-1 accuracy was $(75.2\\pm0.1)\\%$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.utils import check_accuracy\n",
    "\n",
    "label_smooth = True\n",
    "mix_up = True\n",
    "train_time = 2 # Mix-up requires longer training time for better convergence.\n",
    "ckpt_dir = './' # Checkpoint save directory\n",
    "\n",
    "max_epoch = 120\n",
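    "# Smoothed or mixed targets are soft (one-hot style) vectors, so they need the custom loss.\n",
    "if label_smooth or mix_up:\n",
    "    criterion = cross_entropy_onehot_target\n",
    "else:\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Every 30 epochs, divide the learning rate by 10.\n",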
    "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) \n",
    "num_classes = 1000\n",
    "\n",
    "best_acc_1 = 0.0\n",
    "\n",
    "for epoch in range(max_epoch):\n",
    "    for t in range(train_time):\n",
    "        # Re-enable training mode on every pass, since evaluation may switch the model to eval mode.\n",
    "        model.train()\n",
    "        # Reset the running loss for this pass over the training set.\n",
    "        f_avg_val = 0.0\n",
    "        for X, y in trainloader:\n",
    "            X = X.cuda()\n",
    "            y = y.cuda()\n",
    "            with torch.no_grad():\n",
    "                if label_smooth and not mix_up:\n",
    "                    y = one_hot(y, num_classes=num_classes, smoothing_eps=0.1)\n",
    "\n",
    "                if not label_smooth and mix_up:\n",
    "                    y = one_hot(y, num_classes=num_classes)\n",
    "                    X, y = mixup_func(X, y)\n",
    "\n",
    "                if mix_up and label_smooth:\n",
    "                    y = one_hot(y, num_classes=num_classes, smoothing_eps=0.1)\n",
    "                    X, y = mixup_func(X, y)\n",
    "\n",
    "            y_pred = model(X)\n",
    "            f = criterion(y_pred, y)\n",
    "            optimizer.zero_grad()\n",
    "            f.backward()\n",
    "            # Accumulate a Python float so the autograd graph is not kept alive.\n",
    "            f_avg_val += f.item()\n",
    "            optimizer.step()\n",
    "        group_sparsity, omega = optimizer.compute_group_sparsity_omega()\n",
    "        norm_x_pm, norm_x_npm, num_groups_pm, num_groups_npm = optimizer.compute_norm_group_partitions()\n",
    "        accuracy1, accuracy5 = check_accuracy(model, testloader)\n",
    "        f_avg_val = f_avg_val / len(trainloader)\n",
    "        print(\"Epoch: {ep}, loss: {f:.2f}, omega:{om:.2f}, group_sparsity: {gs:.2f}, acc1: {acc:.4f}\"\\\n",
    "            .format(ep=epoch, f=f_avg_val, om=omega, gs=group_sparsity, acc=accuracy1))\n",
    "\n",
    "        # Keep the checkpoint with the best top-1 accuracy so far.\n",
    "        if accuracy1 > best_acc_1:\n",
    "            best_acc_1 = accuracy1\n",
    "            torch.save(model, os.path.join(ckpt_dir, 'best_epoch_' + str(epoch) + '_' + str(t) + '.pt'))\n",
    "    # Step the scheduler once per epoch, after the optimizer updates.\n",
    "    lr_scheduler.step()"
   ]
  },
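  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the step schedule configured above can be written in closed form. This is a sketch that assumes the initial learning rate `lr=0.1` from Step 3 and that the optimizer does not adjust the learning rate on its own; $k$ (the number of `lr_scheduler.step()` calls made so far) is notation introduced here.\n",
    "\n",
    "$$\\mathrm{lr}(k) = 0.1 \\times 0.1^{\\lfloor k / 30 \\rfloor}$$\n",
    "\n",
    "Under these assumptions the learning rate is $0.1$ for $k < 30$, $0.01$ for $30 \\le k < 60$, $10^{-3}$ for $60 \\le k < 90$, and $10^{-4}$ for $90 \\le k < 120$."
   ]
  },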
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5. Get compressed model in ONNX format\n",
    "\n",
    "By default, OTO compresses the last checkpoint.\n",
    "\n",
    "To compress a different checkpoint, reinitialize OTO with it and then compress:\n",
    "\n",
    "    oto = OTO(model=torch.load(ckpt_path), dummy_input=dummy_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A ResNet_compressed.onnx will be generated. \n",
    "oto.compress()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Optional) Compute FLOPs and number of parameters before and after OTO training\n",
    "\n",
    "The compressed ResNet50 under 40% group sparsity reduces FLOPs by 61.5% and parameters by 50.6%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full FLOPs (M): 4089.18. Compressed FLOPs (M): 1574.71. Reduction Ratio: 0.6149\n",
      "Full # Params: 25557032. Compressed # Params: 12626516. Reduction Ratio: 0.5059\n"
     ]
    }
   ],
   "source": [
    "full_flops = oto.compute_flops()\n",
    "compressed_flops = oto.compute_flops(compressed=True)\n",
    "full_num_params = oto.compute_num_params()\n",
    "compressed_num_params = oto.compute_num_params(compressed=True)\n",
    "\n",
    "print(\"Full FLOPs (M): {f_flops:.2f}. Compressed FLOPs (M): {c_flops:.2f}. Reduction Ratio: {f_ratio:.4f}\"\\\n",
    "      .format(f_flops=full_flops, c_flops=compressed_flops, f_ratio=1 - compressed_flops/full_flops))\n",
    "print(\"Full # Params: {f_params}. Compressed # Params: {c_params}. Reduction Ratio: {f_ratio:.4f}\"\\\n",
    "      .format(f_params=full_num_params, c_params=compressed_num_params, f_ratio=1 - compressed_num_params/full_num_params))"
   ]
  },
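  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check, the reduction ratios printed above follow directly from the raw counts in the cell output (each ratio is one minus compressed over full):\n",
    "\n",
    "$$1 - \\frac{1574.71}{4089.18} \\approx 0.6149, \\qquad 1 - \\frac{12626516}{25557032} \\approx 0.5059$$\n",
    "\n",
    "That is, roughly $61.5\\%$ fewer FLOPs and $50.6\\%$ fewer parameters for this run."
   ]
  },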
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Optional) Check the output difference between full model and compressed model.\n",
    "\n",
    "Given the same input, the full and compressed models should return the same output up to floating-point error.\n",
    "\n",
    "The maximum deviation should be at most about `1e-5`, which is negligible.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum output difference:\n",
      "3.33786e-06\n"
     ]
    }
   ],
   "source": [
    "import onnxruntime as ort\n",
    "full_ort_sess = ort.InferenceSession(oto.full_model_path)\n",
    "compress_ort_sess = ort.InferenceSession(oto.compressed_model_path)\n",
    "\n",
    "fake_input = torch.rand(1, 3, 224, 224)\n",
    "\n",
    "full_output = full_ort_sess.run(None, {'input.1': fake_input.numpy()})[0]\n",
    "compress_output = compress_ort_sess.run(None, {'input.1': fake_input.numpy()})[0]\n",
    "print(\"Maximum output difference:\")\n",
    "print(np.max(np.abs(full_output - compress_output)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
